{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/dmlc/gluon-cv/blob/onnx/scripts/onnx/notebooks/action-recognition/i3d_nl10_resnet50_v1_kinetics400.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the running kernel.\n",
    "%pip install --upgrade onnxruntime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "import onnxruntime as rt\n",
    "import urllib.request\n",
    "import os.path\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "(Optional) We use MXNet and GluonCV to read in the video frames.\n",
    "\n",
    "Feel free to read in the video frames your own way.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the environment of the running kernel.\n",
    "%pip install --upgrade mxnet gluoncv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def fetch_model():\n",
    "    \"\"\"Download the pretrained ONNX model on first use and return its file name.\n",
    "\n",
    "    The download is written to a temporary '.part' file and renamed into place\n",
    "    only after urlretrieve succeeds, so an interrupted transfer cannot leave a\n",
    "    truncated model file that a later call would treat as already downloaded.\n",
    "\n",
    "    Returns:\n",
    "        str: Name of the model file in the current working directory.\n",
    "\n",
    "    Raises:\n",
    "        Any exception raised by urllib.request.urlretrieve or os.replace.\n",
    "    \"\"\"\n",
    "    model_url = \"https://apache-mxnet.s3-us-west-2.amazonaws.com/onnx/models/gluoncv-i3d_nl10_resnet50_v1_kinetics400-fce76386.onnx\"\n",
    "    model_file = \"i3d_nl10_resnet50_v1_kinetics400.onnx\"\n",
    "    if not os.path.isfile(model_file):\n",
    "        partial_file = model_file + \".part\"\n",
    "        try:\n",
    "            urllib.request.urlretrieve(model_url, filename=partial_file)\n",
    "            # Rename only after the download finished, so the final name never points at a partial file.\n",
    "            os.replace(partial_file, model_file)\n",
    "        finally:\n",
    "            # The partial file is only still present if the download or rename failed.\n",
    "            if os.path.isfile(partial_file):\n",
    "                os.remove(partial_file)\n",
    "    return model_file\n",
    "    \n",
    "def prepare_video(video_path, input_shape):\n",
    "    \"\"\"Decode a clip from a video file and arrange it as a model input array.\n",
    "\n",
    "    Every second frame is read from the start of the video until\n",
    "    input_shape[2] frames are collected (frames 0, 2, ..., 62 for the 32-frame\n",
    "    shape used in this notebook). The frames go through gluoncv's\n",
    "    VideoGroupValTransform with size input_shape[3] and the mean/std below.\n",
    "\n",
    "    Args:\n",
    "        video_path: Path of the video, passed to decord.VideoReader.\n",
    "        input_shape: Sequence of five entries; index 2 is read as the number of\n",
    "            frames, indices 3 and 4 as the frame height and width.\n",
    "\n",
    "    Returns:\n",
    "        numpy array reshaped to (-1, num_frames, 3, height, width) and then\n",
    "        transposed so that the channel axis comes before the frame axis.\n",
    "    \"\"\"\n",
    "    # Imported inside the function because mxnet/gluoncv are optional for this notebook.\n",
    "    from gluoncv.data.transforms import video\n",
    "    from gluoncv.utils.filesystem import try_import_decord\n",
    "    \n",
    "    decord = try_import_decord()\n",
    "    vr = decord.VideoReader(video_path)\n",
    "    # Derive the sampled frames from the requested shape instead of hard-coding 32 frames.\n",
    "    frame_stride = 2\n",
    "    num_frames = input_shape[2]\n",
    "    frame_id_list = list(range(0, num_frames * frame_stride, frame_stride))\n",
    "    video_data = vr.get_batch(frame_id_list).asnumpy()\n",
    "    # Split the batch along its first axis into one array per sampled frame.\n",
    "    clip_input = list(video_data)\n",
    "    transform_fn = video.VideoGroupValTransform(size=input_shape[3], \n",
    "                                                    mean=[0.485, 0.456, 0.406], \n",
    "                                                    std=[0.229, 0.224, 0.225])\n",
    "    clip_input = transform_fn(clip_input)\n",
    "    clip_input = np.stack(clip_input, axis=0)\n",
    "    clip_input = clip_input.reshape((-1,) + (num_frames, 3, input_shape[3], input_shape[4]))\n",
    "    # Swap the frame and channel axes.\n",
    "    clip_input = np.transpose(clip_input, (0, 2, 1, 3, 4))\n",
    "    \n",
    "    return clip_input\n",
    "    \n",
    "def prepare_label():\n",
    "    \"\"\"Return the `classes` attribute of gluoncv's Kinetics400Attr.\"\"\"\n",
    "    # Imported inside the function because gluoncv is optional for this notebook.\n",
    "    from gluoncv.data import Kinetics400Attr as _Kinetics400Attr\n",
    "    kinetics_attr = _Kinetics400Attr()\n",
    "    return kinetics_attr.classes\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Make sure to replace the placeholder `video_path` below with the path to the video you want to use.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# File name of the ONNX model returned by fetch_model.\n",
    "model = fetch_model()\n",
    "# Placeholder: point this at the video file to classify.\n",
    "video_path = 'Your input'\n",
    "# Shape passed to prepare_video; its last two entries are used there as frame height and width.\n",
    "input_shape = (1, 3, 32, 224, 224)\n",
    "clip_input = prepare_video(video_path, input_shape)\n",
    "# Class names used to label the predictions.\n",
    "label = prepare_label()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Build an ONNX Runtime inference session for the model and look up the name of its first input\n",
    "onnx_session = rt.InferenceSession(model, sess_options=None)\n",
    "first_input = onnx_session.get_inputs()[0]\n",
    "input_name = first_input.name\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Cast the clip to float32, run the session, and keep the first output as the prediction.\n",
    "input_feed = {input_name: clip_input.astype('float32')}\n",
    "outputs = onnx_session.run([], input_feed)\n",
    "pred = outputs[0]\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "(Optional) We use MXNet to process the result.\n",
    "\n",
    "Feel free to process the result your own way.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the running kernel.\n",
    "%pip install --upgrade mxnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import mxnet as mx\n",
    "\n",
    "# Report the topK highest-scoring classes together with their softmax probabilities.\n",
    "pred = mx.nd.array(pred)\n",
    "topK = 5\n",
    "ind = mx.nd.topk(pred, k=topK)[0].astype('int')\n",
    "# The softmax does not depend on the loop variable, so compute it once instead of per iteration.\n",
    "prob = mx.nd.softmax(pred)[0]\n",
    "print('The input is classified to be')\n",
    "for i in range(topK):\n",
    "    print('\\t[%s], with probability %.3f.'%\n",
    "          (label[ind[i].asscalar()], prob[ind[i]].asscalar()))\n",
    "    "
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
